from typing import TypedDict, Dict, List, Annotated, Union, Literal, Optional
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
import json
from core.testAgent.EnvironmentService import EnvironmentService
from langchain_core.messages import HumanMessage, BaseMessage, ToolMessage
import re
from core.testAgent.RetrieverTools import (find_class, find_method_definition, find_variable_definition, find_method_calls,
                            find_method_usages, fuzzy_search, search_similarity_test_class, human_assistance, graph_retriever,
                            bind_tool_job_context)
from core.testAgent.PromptTemplate import Generator_Init_State_Template, Execution_Review_Template, \
    Generator_Init_State_with_Function_Information_Template, Generator_Update_Init_State_Template, \
    Generator_Update_Init_State_with_Function_Information_Template
from langgraph.pregel import RetryPolicy
from langchain.tools.render import render_text_description
from core.testAgent.Config import *
from core.testAgent.Utils import extract_code, extract_json
from core.testAgent.ProgressTracker import progress_tracker

# Tool definitions. The base toolset always includes code-retrieval tools and
# human_assistance; similar-test-class retrieval is excluded when
# Limit_Retrieve_Test_Case is enabled (configured in `Config.py`).
tools = [find_class,
         find_method_definition,
         find_variable_definition,
         find_method_calls,
         find_method_usages,
         fuzzy_search,
         human_assistance]
if not Limit_Retrieve_Test_Case:
    # Keep human_assistance last, matching the original ordering.
    tools.insert(-1, search_similarity_test_class)
tools_without_human = [tool for tool in tools if getattr(tool, "name", "") != "human_assistance"]

# LLM initialization; the `llm` instance is defined in `Config.py`.
if Enable_Native_Function_Call:
    llm_with_tools = llm.bind_tools(tools)
    llm_without_human_tools = llm.bind_tools(tools_without_human)
    # Fix: define BOTH rendered-tool strings in native mode. Previously only
    # rendered_tools_without_human was set, so _get_rendered_tools_text(True)
    # would raise NameError on rendered_tools.
    rendered_tools = ""
    rendered_tools_without_human = ""
else:
    # LLM lacks native function calling: keep the bare model and render the
    # tool descriptions for injection into the prompt instead.
    llm_with_tools = llm
    llm_without_human_tools = llm
    rendered_tools = render_text_description(tools)
    rendered_tools_without_human = render_text_description(tools_without_human)


def _get_active_tools(allow_human: bool):
    """Return the toolset; includes human_assistance only when allowed."""
    if allow_human:
        return tools
    return tools_without_human


def _get_llm_runner(allow_human: bool):
    """Return the LLM bound to the toolset matching the human-assistance flag."""
    if allow_human:
        return llm_with_tools
    return llm_without_human_tools


def _get_rendered_tools_text(allow_human: bool):
    """Return the rendered tool descriptions matching the human-assistance flag."""
    if allow_human:
        return rendered_tools
    return rendered_tools_without_human


# State definition
class GeneratorState(TypedDict):
    """Shared state flowing through the generator state graph."""
    envServer: EnvironmentService  # environment service (compiles/executes tests)
    allow_human_assistance: bool  # whether the human_assistance tool may be invoked
    feedback_times: int  # number of feedback iterations performed so far
    max_feedback_times: int  # maximum number of feedback iterations
    find_bug: bool  # whether a bug was found in the method under test
    bug_report: str  # bug report
    package_name: str  # package of the method under test
    method_code: str  # source code of the method under test
    method_signature: str  # signature of the method under test
    class_name: str  # name of the enclosing class
    full_method_name: str  # fully-qualified method name
    start_line: int  # first source line of the method under test
    end_line: int  # last source line of the method under test
    method_summary: str  # summary of the method's purpose/behavior
    requirement: Dict  # generated requirement list (test specification)
    test_case: str  # generated test case source
    test_class_name: str  # name of the generated test class
    compile_result: bool  # whether compilation succeeded
    execute_result: bool  # whether execution succeeded
    test_result: str  # test outcome ("Success"/"Syntax Error"/"Compile Error"/"Execute Error")
    test_report: str  # test output/report text
    coverage_report: str  # coverage report (in practice a result dict — TODO confirm type)
    mutation_report: str  # mutation-testing report (in practice a result dict — TODO confirm type)
    # Conversation history (merged via the add_messages reducer, LLM-ready)
    messages: Annotated[List[Union[dict, BaseMessage]], add_messages]
    log_message: list  # log messages
    generate_or_evaluation: Literal["generation", "evaluation"]  # generation or evaluation mode
    job_id: Optional[str]  # progress-tracking job id


def generator_init_state(state: GeneratorState):
    """
    **State node**: initialize the generation state.

    Builds the initial prompt from the method-under-test metadata. When
    ``state["test_case"]`` is non-empty the run is an update of an existing
    test class, so its class name is extracted and an update-style template
    is used. When the LLM lacks native function calling, rendered tool
    descriptions are injected into the prompt.
    """
    allow_human = state.get("allow_human_assistance", True)
    requirement = state["requirement"]
    requirement.pop("test_case", None)
    test_class_name = ""
    # Prompt fields shared by both the fresh-generation and update branches.
    prompt_information = {
        "class_name": state["class_name"],
        "package_name": state["package_name"],
        "method_signature": state["method_signature"],
        "method_summary": state["method_summary"],
        "test_specification": json.dumps(requirement, indent=4),
    }
    updating_existing = state["test_case"] != ""
    if updating_existing:
        # Recover the existing test class name for reuse downstream.
        class_pattern = r'public class\s+(\w+)\s*(?:\{|\s+extends|\s+implements|$)'
        found = re.findall(class_pattern, state["test_case"])
        test_class_name = found[0] if found else ""
        prompt_information["test_case"] = state["test_case"]
        if Enable_Native_Function_Call:
            init_prompt = Generator_Update_Init_State_Template.invoke(prompt_information)
        else:
            # No native function calling: add tool descriptions to the system prompt.
            prompt_information["rendered_tools"] = _get_rendered_tools_text(allow_human)
            init_prompt = Generator_Update_Init_State_with_Function_Information_Template.invoke(prompt_information)
    else:
        if Enable_Native_Function_Call:
            init_prompt = Generator_Init_State_Template.invoke(prompt_information)
        else:
            # No native function calling: add tool descriptions to the system prompt.
            prompt_information["rendered_tools"] = _get_rendered_tools_text(allow_human)
            init_prompt = Generator_Init_State_with_Function_Information_Template.invoke(prompt_information)
    valid_prompt = init_prompt.to_messages()
    return {"messages": valid_prompt, "feedback_times": 0, "max_feedback_times": 3, "find_bug": False,
            "log_message": [], "compile_result": False, "execute_result": False, "test_case": None,
            "test_class_name": test_class_name, "bug_report": "",
            "allow_human_assistance": allow_human}


def testMethodGenerator(state: GeneratorState):
    """
    **State node**: test case generation.

    Core state of the Executor state machine: invokes the LLM on the current
    message history and validates the reply. When the reply requests unknown
    tools (native function-call mode) or malformed/unknown tool-call JSON
    (non-native mode), corrective feedback is appended and the LLM is
    re-invoked, up to ``max_try_times`` attempts.

    Bug fixes versus the previous version:
    * the re-invoked reply is now re-validated on the next loop iteration —
      previously the native branch's ``continue`` targeted the inner ``for``
      and the loop ``break``-ed unconditionally afterwards, so a retried
      reply's tool calls were never checked;
    * in non-native mode an unknown tool name now triggers feedback —
      previously ``check_result`` was set to ``True`` even after the
      validation loop hit an unknown name.
    """
    allow_human = state.get("allow_human_assistance", True)
    llm_runner = _get_llm_runner(allow_human)
    tools_by_name = {tool.name: tool for tool in _get_active_tools(allow_human)}
    result = llm_runner.invoke(state["messages"])
    state["messages"].append(result)
    log = []
    max_try_times = 3

    def _validation_feedback(reply):
        """Return a corrective HumanMessage, or None when the reply is acceptable."""
        if Enable_Native_Function_Call:
            if not reply.tool_calls:
                return None
            unknown = [c["name"] for c in reply.tool_calls if c["name"] not in tools_by_name]
            if not unknown:
                return None
            return HumanMessage(
                f"Tool {unknown[0]} not found. please check the tool name and try again."
                f"Available tools are: {', '.join(tools_by_name.keys())}"
            )
        expected_schema = {"name": str, "arguments": dict}
        ret, intention = extract_json(reply.content, expected_schema)
        # Valid only when we parsed a list AND every requested tool exists.
        check_result = isinstance(ret, list) and all(r["name"] in tools_by_name for r in ret)
        if intention and not check_result:
            return HumanMessage(
                "If you want to use tools, please specify the tool name and arguments in the JSON format. "
                "For example, ```json\n{'name': 'find_class', 'arguments': {'class_name': 'Test'}}\n```"
                "Available tools are: " + ', '.join(tools_by_name.keys()) + "."
                                                                            "Please try again."
            )
        return None

    for _ in range(max_try_times):
        feedback = _validation_feedback(result)
        if feedback is None:
            break
        state["messages"].append(feedback)
        log.append(result)
        log.append(feedback)
        result = llm_runner.invoke(state["messages"])
        state["messages"].append(result)

    return {"messages": [result], "log_message": log}


def call_tool(state: GeneratorState):
    """
    **State node**: tool execution.

    Runs the tool calls requested by the last LLM message and returns the
    results as messages: ToolMessage in native function-call mode,
    HumanMessage otherwise. Two guards apply to ``human_assistance``:

    * it is rejected until at least two automatic repair attempts have been
      made (``feedback_times >= 2``);
    * it is rejected when ``allow_human_assistance`` is disabled.

    Every call — success or failure — is recorded on the progress tracker.
    """
    allow_human = state.get("allow_human_assistance", True)
    tools_by_name = {tool.name: tool for tool in _get_active_tools(allow_human)}
    last_message = state["messages"][-1]
    output_messages = []
    job_id = state.get("job_id")
    bind_tool_job_context(job_id)
    # feedback_times cannot change while this node runs, so read it once.
    feedback_times = state.get("feedback_times", 0)
    if Enable_Native_Function_Call:
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            arguments = tool_call.get("args") or {}
            if tool_name == "human_assistance" and feedback_times < 2:
                output_messages.append(
                    ToolMessage(
                        content=json.dumps({"status": "error",
                                            "message": "请至少完成两次自动修复尝试后再请求人类协助。"}),
                        name=tool_name,
                        tool_call_id=tool_call["id"],
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments,
                                             detail="insufficient attempts before human assistance")
                continue
            if tool_name == "human_assistance" and not allow_human:
                output_messages.append(
                    ToolMessage(
                        content=json.dumps({"error": "human assistance disabled"}),
                        name=tool_name,
                        tool_call_id=tool_call["id"],
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments,
                                             detail="human assistance disabled")
                continue
            try:
                if tool_name == "search_similarity_test_class":
                    # This tool is keyed by the current test class name,
                    # not by LLM-provided arguments.
                    arguments = state["test_class_name"]
                tool_result = tools_by_name[tool_name].invoke(arguments)
                output_messages.append(
                    ToolMessage(
                        content=json.dumps(tool_result),
                        name=tool_name,
                        tool_call_id=tool_call["id"],
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, arguments=arguments)
            except Exception as e:
                # Surface the failure to the LLM instead of crashing the graph.
                # Store the error as text so the message stays serializable
                # (previously the raw exception object was stored).
                output_messages.append(
                    ToolMessage(
                        content=f"error: {e}",
                        name=tool_name,
                        tool_call_id=tool_call["id"],
                        additional_kwargs={"error": str(e)},
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments, detail=str(e))
    else:
        expected_schema = {"name": str, "arguments": dict}
        results, intention = extract_json(last_message.content, expected_schema)
        assert results, "Tool call not found. Generally it is impossible to reach here."
        for tool_call in results:
            tool_name = tool_call["name"]
            arguments = tool_call["arguments"]
            if tool_name == "human_assistance" and feedback_times < 2:
                output_messages.append(
                    HumanMessage("请先自行尝试至少两次修复后再请求人类协助。")
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments,
                                             detail="insufficient attempts before human assistance")
                continue
            if tool_name == "human_assistance" and not allow_human:
                output_messages.append(
                    HumanMessage("Human assistance 已关闭，请在“自主完成”模式下继续执行任务。")
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments,
                                             detail="human assistance disabled")
                continue
            try:
                if tool_name == "search_similarity_test_class":
                    # This tool is keyed by the current test class name,
                    # not by LLM-provided arguments.
                    arguments = state["test_class_name"]
                tool_result = tools_by_name[tool_name].invoke(arguments)
                output_messages.append(
                    HumanMessage(
                        f"Tool call: {tool_name}, arguments: {arguments}, result: {json.dumps(tool_result)}"
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, arguments=arguments)
            except Exception as e:
                # Report the tool failure back to the LLM as plain text.
                output_messages.append(
                    HumanMessage(
                        f"Tool call: {tool_name}, arguments: {arguments}, error: {e}"
                    )
                )
                progress_tracker.record_tool(job_id, tool_name, status="error", arguments=arguments, detail=str(e))
    return {"messages": output_messages}


def generatorChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the output of the 'testMethodGenerator' node:
    go to the tool node when the reply contains tool calls,
    otherwise proceed to code extraction.
    """
    last_message = state["messages"][-1]
    if Enable_Native_Function_Call:
        wants_tools = bool(last_message.tool_calls)
    else:
        parsed, _ = extract_json(last_message.content, {"name": str, "arguments": dict})
        wants_tools = bool(parsed)
    return "tools" if wants_tools else "codeExtractor"


def codeExtractor(state: GeneratorState):
    """
    **State node**: code extraction.
    Pulls the Java code out of the latest LLM reply and checks that it is
    runnable: it must contain imports, a class keyword, and a public class
    declaration. On any check failure a corrective message is returned so
    the generator node can retry.
    """
    last_message = state["messages"][-1]
    test_case = extract_code(last_message.content)
    if not test_case:
        return {"messages": [HumanMessage(
            "Have you fully understood the tested method along with its contextual dependencies? \n"
            "If yes, please proceed to generate the JUnit test case in a Markdown code block (```java ```). \n"
            "If not, continue analyzing and searching for relevant context and dependencies until you are ready."
        )]}
    if "import" not in test_case or "class" not in test_case:
        return {"messages": [HumanMessage(
            "The test case is not runnable. Make sure it includes the necessary package, import sentence and class definitions")]}
    class_names = re.findall(r'public class\s+(\w+)\s*(?:\{|\s+extends|\s+implements|$)', test_case)
    if not class_names:
        return {"messages": [HumanMessage(
            "The generated test must declare a public class (e.g., `public class SampleTest { ... }`). "
            "Please wrap your test methods inside such a class and resend the complete snippet."
        )]}
    fixed_case = state["envServer"].simple_fix(test_case)
    return {"test_case": fixed_case, "test_class_name": class_names[0]}


def extractChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the output of the 'codeExtractor' node:
    return to test generation when no usable code was extracted,
    otherwise proceed to compilation.
    """
    return "compilation" if state["test_case"] else "testMethodGenerator"


def compilation(state: GeneratorState):
    """
    **State node**: compilation.
    Compiles the generated test case and records the outcome.
    """
    assert state["test_case"] is not None
    outcome = state["envServer"].run_compile_test(state["test_case"], state["test_class_name"])
    return {"test_result": outcome["result"],
            "test_report": str(outcome["output"]),
            "compile_result": outcome["result"] == "Success"}


def compilationChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the 'compilation' node's outcome:
    run the test on success, otherwise iterate on feedback.
    """
    return "execution" if state["test_result"] == "Success" else "feedbackIteration"


def execution(state: GeneratorState):
    """
    **State node**: execution.
    Runs the compiled test case and records the outcome.
    """
    assert state["test_case"] is not None
    outcome = state["envServer"].run_execute_test(state["test_case"], state["test_class_name"])
    return {"test_result": outcome["result"],
            "test_report": str(outcome["output"]),
            "execute_result": outcome["result"] == "Success"}


def executionChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the 'execution' node's outcome:
    generate the report on success, otherwise review the failure.
    """
    return "reportGenerator" if state["test_result"] == "Success" else "executionReview"


def feedbackIteration(state: GeneratorState):
    """
    **State node**: feedback iteration.
    Produces a corrective message for a failed compile/execute result and
    increments the feedback counter. Unknown result values yield None,
    matching the previous implicit behavior.
    """
    failure_phrases = {
        "Syntax Error": "The test case contains syntax errors. Please fix the errors and try again.",
        "Compile Error": "The test case failed to compile. Please fix the errors and try again.",
        "Execute Error": "The test case failed to execute. Please fix the errors and try again.",
    }
    phrase = failure_phrases.get(state["test_result"])
    if phrase is None:
        # Unexpected test_result: preserve the original implicit-None return.
        return None
    guidance = ("You can use tools to help you find and fix the errors."
                f"tools: {', '.join([tool.name for tool in tools])}"
                "The error message is as follows: \n")
    return {"feedback_times": state["feedback_times"] + 1,
            "messages": [HumanMessage(phrase + guidance + state["test_report"])]}


def feedbackIterationChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the 'feedbackIteration' node's outcome:
    retry generation while under the feedback limit, otherwise proceed to
    report generation (whether or not a test case was produced).
    """
    under_limit = state["feedback_times"] < state["max_feedback_times"]
    return "testMethodGenerator" if under_limit else "reportGenerator"


def executionReview(state: GeneratorState):
    """
    **State node**: execution review.
    Asks the LLM to reflect on a failed execution and decide whether the
    failure stems from the method under test (find_bug=True) or from the
    test case itself (find_bug=False). Raises ValueError on an unparsable
    review so the node's RetryPolicy can re-run it.
    """
    assert state["test_result"] == "Execute Error"
    prompt = Execution_Review_Template.invoke({
        "execution_report": state["test_report"],
        "test_specification": json.dumps(state["requirement"], indent=4),
    })
    valid_prompt = prompt.to_messages()
    result = llm.invoke(valid_prompt)
    json_block = re.search(r'```json(.*?)```', result.content, re.DOTALL)
    review = json.loads(json_block.group(1)) if json_block else None
    if review is None:
        raise ValueError("Invalid review result.")
    # Only an explicit "method_bug" verdict counts as a bug in the method.
    find_bug = review["issue"] == "method_bug"
    valid_prompt.append(result)
    return {"messages": valid_prompt, "find_bug": find_bug, "bug_report": review}


def reviewChecker(state: GeneratorState):
    """
    **Branch decision**
    Route based on the 'executionReview' node's verdict:
    when a method bug was found the test case is considered correct, so go
    to report generation; otherwise iterate on feedback.
    """
    return "reportGenerator" if state["find_bug"] else "feedbackIteration"


def reportGenerator(state: GeneratorState):
    """
    **State node**: report generation.
    Marks the end of the generation phase; passes the test case through.
    """
    return {"test_case": state["test_case"]}


def coverage_report(state: GeneratorState):
    """
    **State node**: coverage report.
    Runs the coverage test and returns its report; short-circuits with a
    Compile Error placeholder when compilation never succeeded.
    """
    if not state["compile_result"]:
        return {"coverage_report": {"result": "Compile Error", "output": "Compile Error"}}
    # The focal class FQ name is the method FQ name minus the trailing ".method".
    full_class_name = state["full_method_name"].rsplit('.', 1)[0]
    report = state["envServer"].run_coverage_test(
        state["test_case"], state["test_class_name"], state["package_name"],
        full_class_name, state["method_signature"], state["start_line"],
        state["end_line"])
    return {"coverage_report": report}


def mutation_report(state: GeneratorState):
    """
    **State node**: mutation-testing report.
    Runs the mutation test and returns its report; short-circuits with an
    Execute Error placeholder when the test never executed successfully.
    """
    if not state["execute_result"] or state["test_result"] == "Compile Error":
        return {"mutation_report": {"result": "Execute Error", "output": "Execute Error"}}
    # Second-to-last dotted segment of the method FQ name is the class name.
    class_name = state["full_method_name"].rsplit('.', 2)[-2]
    report = state["envServer"].run_mutation_test(
        state["package_name"], class_name, state["test_case"],
        state["test_class_name"],
        state["start_line"], state["end_line"])
    return {"mutation_report": report}


def add_testcase_to_CKG(state: GeneratorState):
    """
    **State node**: add the generated test case to the CKG.

    Skipped in evaluation mode. Otherwise the test class is given a unique
    numbered name, registered in the environment, and its metadata
    (test result, coverage, mutation score, bug report, requirement) is
    written to the code knowledge graph.
    """
    if state["generate_or_evaluation"] == "evaluation":
        return {"messages": [HumanMessage("Evaluation mode, no need to add test case to CKG.")]}
    envServer = state["envServer"]
    test_case = state["test_case"]
    # Prefix with the environment's running counter so repeated generations
    # produce unique class names.
    test_class_name = "_" + str(envServer.number) + "_" + state["test_class_name"]
    envServer.number += 1
    # Rename the class inside the source to avoid conflicts with later
    # test classes of the same name.
    # NOTE(review): plain string replace also hits other occurrences of the
    # class name (e.g. in comments or constructors) — presumably intended.
    test_case = test_case.replace(state["test_class_name"], test_class_name)
    ret = envServer.add_test_to_CKG(test_case, test_class_name)
    if ret["result"] == "Error":
        return {"messages": [HumanMessage("Failed to add test case to CKG.")]}
    focal_clazz_name = state["class_name"]
    focal_method_fq_name = state["full_method_name"]
    # Truncate the method FQ name just after the class name to get the
    # class FQ name; fall back to the bare class name if not found.
    index = focal_method_fq_name.find(focal_clazz_name)
    if index != -1:
        focal_clazz_fq_name = focal_method_fq_name[:index + len(focal_clazz_name)]
    else:
        focal_clazz_fq_name = focal_clazz_name
    find_bug = state["find_bug"]
    bug_report = state["bug_report"]
    method_signature = state["method_signature"]
    requirement = state["requirement"]
    test_report = state["test_result"]  # "Success" or "Execute Error" or "Compile Error" or "Syntax Error"
    # Use neutral defaults ("0" / empty) when coverage or mutation runs failed.
    coverage_rate = state["coverage_report"]["output"]["line_coverage"] if state["coverage_report"][
                                                                               "result"] == "Success" else "0"
    coverage_lines = state["coverage_report"]["output"]["covered_lines"] if state["coverage_report"][
                                                                                "result"] == "Success" else []
    mutation_score = state["mutation_report"]["output"]["mutation_score"] if state["mutation_report"][
                                                                                 "result"] == "Success" else "0"
    mutants = state["mutation_report"]["output"]["filtered_mutations"] if state["mutation_report"][
                                                                              "result"] == "Success" else {}
    result = graph_retriever.update_test_class(test_class_name, focal_clazz_fq_name, focal_method_fq_name,
                                               method_signature,
                                               test_report, coverage_rate, coverage_lines, mutation_score,
                                               json.dumps(mutants),
                                               find_bug, json.dumps(bug_report), json.dumps(requirement))
    return {"messages": [HumanMessage("Test case added to CKG successfully.")], "test_case": test_case,
            "test_class_name": test_class_name}


##########################################################################
############################## 进行状态图的构建 #############################
##########################################################################


generator_graph = StateGraph(GeneratorState)

# Register state nodes. Generation and review nodes carry a RetryPolicy
# because they raise ValueError when LLM output cannot be parsed.
generator_graph.add_node("init", generator_init_state)
generator_graph.add_node("testMethodGenerator", testMethodGenerator,
                         retry=RetryPolicy(max_attempts=3, retry_on=ValueError))
generator_graph.add_node("tools", call_tool)
generator_graph.add_node("codeExtractor", codeExtractor)
generator_graph.add_node("compilation", compilation)
generator_graph.add_node("execution", execution)
generator_graph.add_node("feedbackIteration", feedbackIteration)
generator_graph.add_node("reportGenerator", reportGenerator)
generator_graph.add_node("coverageReport", coverage_report)
generator_graph.add_node("mutationReport", mutation_report)
generator_graph.add_node("executionReview", executionReview, retry=RetryPolicy(max_attempts=3, retry_on=ValueError))
generator_graph.add_node("addTestCaseToCKG", add_testcase_to_CKG)

# Wire the graph: generate -> (tools loop) -> extract -> compile -> execute,
# with feedback/review loops on failure, then report/coverage/mutation/CKG.
generator_graph.add_edge(START, "init")
generator_graph.add_edge("init", "testMethodGenerator")
generator_graph.add_conditional_edges("testMethodGenerator", generatorChecker,
                                      {"tools": "tools", "codeExtractor": "codeExtractor"})
generator_graph.add_edge("tools", "testMethodGenerator")
generator_graph.add_conditional_edges("codeExtractor", extractChecker,
                                      {"testMethodGenerator": "testMethodGenerator",
                                       "compilation": "compilation"})
generator_graph.add_conditional_edges("compilation", compilationChecker,
                                      {"execution": "execution", "feedbackIteration": "feedbackIteration"})
generator_graph.add_conditional_edges("execution", executionChecker,
                                      {"reportGenerator": "reportGenerator", "executionReview": "executionReview"})
generator_graph.add_conditional_edges("feedbackIteration", feedbackIterationChecker,
                                      {"testMethodGenerator": "testMethodGenerator",
                                       "reportGenerator": "reportGenerator"})
generator_graph.add_conditional_edges("executionReview", reviewChecker,
                                      {"feedbackIteration": "feedbackIteration", "reportGenerator": "reportGenerator"})
generator_graph.add_edge("reportGenerator", "coverageReport")
generator_graph.add_edge("coverageReport", "mutationReport")
generator_graph.add_edge("mutationReport", "addTestCaseToCKG")
generator_graph.add_edge("addTestCaseToCKG", END)

generator_graph = generator_graph.compile()

if __name__ == "__main__":
    # When run as a script, render the compiled graph to a PNG for inspection.
    try:
        png_bytes = generator_graph.get_graph().draw_mermaid_png()
        with open("Generator.png", "wb") as diagram_file:
            diagram_file.write(png_bytes)
    except Exception as err:
        print(err)
        print("Failed to generate the image.")
    print("Graph compiled successfully.")
